//
//  crop.metal
//  EffectMgrMetal
//
//  Created by WS on 2021/5/26.
//  Copyright © 2021 WS. All rights reserved.
//

#include <metal_stdlib>
using namespace metal;

// Clamps the lvalue `v` in place to the closed range [min, max].
// - Every argument is parenthesized so that compound expressions (e.g. a
//   ternary or `a - 1`) passed as bounds keep their intended precedence.
// - The body is enclosed in braces so the internal if/else chain cannot
//   capture an `else` that follows the macro at the call site.
// The macro expands to a compound statement, so it may be written with or
// without a trailing semicolon. `min` and `max` may be evaluated twice.
#define CLAMP(v, min, max) \
    { \
        if ((v) < (min)) { \
            (v) = (min); \
        } else if ((v) > (max)) { \
            (v) = (max); \
        } \
    }

// Reads the texel at (x, y) from `in`, clamping the coordinate to the
// rectangle [0, inW - 1] x [0, inH - 1] first.
//
// The coordinates are declared unsigned, but callers in this file pass signed
// ints that can be negative. A negative int arrives here wrapped to a very
// large unsigned value; comparing that against 0 never triggers, so it would
// be clamped to the far edge instead of the near one. Converting back to a
// signed int before clamping restores the intended "clamp to edge" behaviour
// on both sides.
//
// in   : source texture (read access).
// x, y : texel coordinate; may be a wrapped negative value.
// inW  : width used for clamping, in texels (expected >= 1).
// inH  : height used for clamping, in texels (expected >= 1).
static float4 GetPixelClamped(texture2d<float, access::read> in [[texture(0)]], uint x, uint y, int inW, int inH) {
    const int sx = clamp(int(x), 0, inW - 1);
    const int sy = clamp(int(y), 0, inH - 1);
    return in.read(uint2(uint(sx), uint(sy)));
}

// Linear interpolation between A and B.
// Returns A when t == 0 and B when t == 1; t is not clamped, so values
// outside [0, 1] extrapolate.
static float Lerp (float A, float B, float t) {
    const float weightA = 1.0f - t;
    const float fromA = A * weightA;
    const float fromB = B * t;
    return fromA + fromB;
}

// Bilinearly samples `in` at the normalized coordinate (u, v), where (0, 0)
// and (1, 1) span the inW x inH texel rectangle. Coordinates that fall outside
// the rectangle use the nearest edge texel (clamp-to-edge).
//
// in       : source texture (read access).
// u, v     : normalized sample position.
// inW, inH : source size in texels (expected >= 1).
// Returns the interpolated colour, each channel limited to [0, 255].
static float4 SampleBilinear (texture2d<float, access::read> in [[texture(0)]],
                       float u, float v, int inW, int inH) {
    // Offset by half a pixel so texel centres line up and the image does not
    // shift down and left by half a pixel.
    float x = u * float(inW) - 0.5f;
    float y = v * float(inH) - 0.5f;

    // Derive the integer and fractional parts from the same floor() so they
    // agree for negative positions. (int(x) truncates toward zero, which for
    // x in (-1, 0) yields 0 while the fraction is measured from -1.)
    float xfloor = floor(x);
    float yfloor = floor(y);
    float xfract = x - xfloor;
    float yfract = y - yfloor;

    // Limit the base index in float space first so the int conversion and the
    // "+ 1" below cannot overflow for far out-of-range coordinates.
    int xbase = int(clamp(xfloor, -1.0f, float(inW)));
    int ybase = int(clamp(yfloor, -1.0f, float(inH)));

    // Clamp the four tap indices here, in signed arithmetic, because
    // GetPixelClamped takes unsigned coordinates.
    int x0 = clamp(xbase,     0, inW - 1);
    int x1 = clamp(xbase + 1, 0, inW - 1);
    int y0 = clamp(ybase,     0, inH - 1);
    int y1 = clamp(ybase + 1, 0, inH - 1);

    // get pixels
    auto p00 = GetPixelClamped(in, x0, y0, inW, inH);
    auto p10 = GetPixelClamped(in, x1, y0, inW, inH);
    auto p01 = GetPixelClamped(in, x0, y1, inW, inH);
    auto p11 = GetPixelClamped(in, x1, y1, inW, inH);

    // Interpolate along x on both rows, then along y between the rows.
    float4 ret;
    for (int i = 0; i < 4; ++i)
    {
        float col0 = Lerp(p00[i], p10[i], xfract);
        float col1 = Lerp(p01[i], p11[i], xfract);
        float value = Lerp(col0, col1, yfract);
        // NOTE(review): the upper bound of 255 suggests 8-bit style values;
        // confirm against the texture's pixel format.
        CLAMP(value, 0.0f, 255.0f);
        ret[i] = value;
    }
    return ret;
}

// Maps tc.x into a horizontally centred band that occupies the fraction
// `band` of the destination width. Returns (remapped x, unchanged y, matt),
// where matt is 1 inside the band (edges included) and 0 outside it.
static float3 ResizeBandX(float2 tc, float band) {
    float roi0 = (1.0f - band) * 0.5f;
    float roi1 = roi0 + band;
    float matt = step(roi0, tc.x) * step(tc.x, roi1);
    return float3((tc.x - roi0) / band, tc.y, matt);
}

// Maps tc.y into a vertically centred band that occupies the fraction
// `band` of the destination height. Returns (unchanged x, remapped y, matt),
// where matt is 1 inside the band (edges included) and 0 outside it.
static float3 ResizeBandY(float2 tc, float band) {
    float roi0 = (1.0f - band) * 0.5f;
    float roi1 = roi0 + band;
    float matt = step(roi0, tc.y) * step(tc.y, roi1);
    return float3(tc.x, (tc.y - roi0) / band, matt);
}

// Fits a virtual extent of virtW x virtH into the dstW x dstH destination
// while keeping the virtual aspect ratio, centring it on the axis that has
// spare room. Returns (x, y, matt) as produced by the band helpers.
static float3 ResizeFitVirtual(float2 tc, float virtW, float virtH, float dstW, float dstH) {
    float tempWidth  = virtW * dstH;
    float tempHeight = virtH * dstW;
    if (tempHeight >= tempWidth)
    {
        // The destination is relatively wider: limit the width.
        float w = tempWidth / virtH;
        return ResizeBandX(tc, w / dstW);
    }
    // The destination is relatively taller: limit the height.
    float h = tempHeight / virtW;
    return ResizeBandY(tc, h / dstH);
}

// Resizes texture `in` into texture `out` with bilinear filtering, one thread
// per destination texel.
//
// in          : source texture, texture(0), read access.
// out         : destination texture, texture(1), write access.
// inW, inH    : source size in texels, buffer(0) / buffer(1).
// outW, outH  : destination size in texels, buffer(2) / buffer(3).
// resize_type : buffer(4); selects how the source is placed in the destination:
//     1     - fit the source aspect ratio into the destination, centred;
//     2     - as 1, but the source extent is first grown to 16:9;
//     3     - as 1, but the source extent is first grown to 4:3;
//     4     - as 1, but the source extent is first reduced to 9:16;
//     5     - as 1, but the source extent is first reduced to a square;
//     other - stretch the source over the whole destination.
//   In modes 2-5 only the placement rectangle uses the adjusted extent; the
//   sample coordinates still span the whole source texture.
// gid         : destination texel handled by this thread.
//
// Texels outside the placement rectangle are written with all channels
// multiplied by 0.
kernel void resize(texture2d<float, access::read> in [[texture(0)]],
                        texture2d<float, access::write> out [[texture(1)]],
                        constant int *inW [[buffer(0)]],
                        constant int *inH [[buffer(1)]],
                        constant int *outW [[buffer(2)]],
                        constant int *outH [[buffer(3)]],
                        constant int *resize_type[[buffer(4)]],
                        uint2 gid [[thread_position_in_grid]])
{
    // The dispatch grid may be padded beyond the texture; never write outside it.
    if (gid.x >= out.get_width() || gid.y >= out.get_height())
    {
        return;
    }

    // Normalized destination coordinate in [0, 1]. The divisor is kept at
    // least 1 so a 1-texel-wide or 1-texel-high output yields 0 instead of
    // 0 / 0 = NaN.
    float u = float(gid.x) / float(max(*outW - 1, 1));
    float v = float(gid.y) / float(max(*outH - 1, 1));
    float2 tc = float2(u, v);

    float f_srcWidth = float(*inW);
    float f_srcHeight = float(*inH);
    float f_dstWidth = float(*outW);
    float f_dstHeight = float(*outH);

    // (x, y) = source sample coordinate, z = matt (1 keeps the texel, 0 blanks it).
    float3 mapped = float3(tc, 1.0f);

    switch (*resize_type)
    {
        case 1:
        {
            // Height the source would occupy when scaled to the full destination width.
            float tmpDstH = ceil(f_dstWidth * f_srcHeight / f_srcWidth);
            if (f_dstHeight >= tmpDstH)
            {
                mapped = ResizeBandY(tc, tmpDstH / f_dstHeight);
            }
            else
            {
                float src_ratio = f_srcWidth / f_srcHeight;
                float tmpDstW = ceil(f_dstHeight * src_ratio);
                mapped = ResizeBandX(tc, tmpDstW / f_dstWidth);
            }
            break;
        }
        case 2:
        {
            // Grow the shorter side so the virtual extent is 16:9.
            float virtW = f_srcWidth;
            float virtH = f_srcHeight;
            if (virtW * 9 > virtH * 16)
            {
                virtH = ceil(virtW * 9 / 16);
            }
            else
            {
                virtW = ceil(virtH * 16 / 9);
            }
            mapped = ResizeFitVirtual(tc, virtW, virtH, f_dstWidth, f_dstHeight);
            break;
        }
        case 3:
        {
            // Grow the shorter side so the virtual extent is 4:3 (no rounding).
            float virtW = f_srcWidth;
            float virtH = f_srcHeight;
            if (virtW * 3 > virtH * 4)
            {
                virtH = virtW * 3 / 4;
            }
            else
            {
                virtW = virtH * 4 / 3;
            }
            mapped = ResizeFitVirtual(tc, virtW, virtH, f_dstWidth, f_dstHeight);
            break;
        }
        case 4:
        {
            // Replace one side so the virtual extent is 9:16.
            float virtW = f_srcWidth;
            float virtH = f_srcHeight;
            if (virtW * 16 > virtH * 9)
            {
                virtW = ceil(virtH * 9 / 16);
            }
            else
            {
                virtH = ceil(virtW * 16 / 9);
            }
            mapped = ResizeFitVirtual(tc, virtW, virtH, f_dstWidth, f_dstHeight);
            break;
        }
        case 5:
        {
            // Reduce the longer side so the virtual extent is square.
            float side = (f_srcWidth >= f_srcHeight) ? f_srcHeight : f_srcWidth;
            mapped = ResizeFitVirtual(tc, side, side, f_dstWidth, f_dstHeight);
            break;
        }
        default:
            break;
    }

    float4 ovlCol = SampleBilinear(in, mapped.x, mapped.y, *inW, *inH) * mapped.z;

    out.write(ovlCol, gid);
}
